# -*- coding:utf-8 -*- 
#Author: OCEAN
#
#                            _ooOoo_
#                           o8888888o
#                           88" . "88
#                           (| -_- |)
#                           O\  =  /O
#                        ____/`---'\____
#                      .'  \\|     |//  `.
#                     /  \\|||  :  |||//  \
#                    /  _||||| -:- |||||-  \
#                    |   | \\\  -  /// |   |
#                    | \_|  ''\---/''  |   |
#                    \  .-\__  `-`  ___/-. /
#                  ___`. .'  /--.--\  `. . __
#               ."" '<  `.___\_<|>_/___.'  >'"".
#              | | :  `- \`.;`\ _ /`;.`/ - ` : | |
#              \  \ `-.   \_ __\ /__ _/   .-` /  /
#         ======`-.____`-.___\_____/___.-`____.-'======
#                            `=---='
#        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#                    Buddha-like Programming......
import os
import os.path

import torch
import torch.utils.data as data
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms

from utils.folder_utils import default_loader

import cv2
import matplotlib.pyplot as plt

# Root of the local WIDER FACE copy. The default is machine-specific; set the
# WIDERFACE_ROOT environment variable to point at another copy without editing
# this file (an unset or empty variable falls back to the default).
root_dir = os.environ.get('WIDERFACE_ROOT') or '/media/disk4/share/widerface/'
# Per-split image directories and the annotation / file-list directory, all
# relative to root_dir.
train_folder = os.path.join(root_dir, 'Training Images/WIDER_train/images/')
val_folder = os.path.join(root_dir, 'Validation Images/WIDER_val/images/')
test_folder = os.path.join(root_dir, 'Testing Images/WIDER_test/images/')
label_folder = os.path.join(root_dir, 'wider_face_split/')


class WiderFaceDataSet(Dataset):
    """WIDER FACE images, paired with ground-truth boxes for labelled splits.

    Items are ``(img, bbxes)`` for the 'train' and 'val' phases and just
    ``img`` for the 'test' phase (which is read from a plain file list with
    no boxes). ``bbxes`` is a list with one entry per annotated face, each
    entry holding the first four integers of that face's annotation line
    (presumably x, y, w, h in the WIDER FACE gt format -- confirm against
    the split readme).

    ``transform`` is applied to the image only; boxes are never adjusted, so
    geometric transforms such as ``Resize`` leave ``bbxes`` in the original
    image's pixel coordinates.
    """

    def __init__(self,
                 phase='train',
                 transform=None,
                 target_transform=None,
                 loader=default_loader,
                 ):
        """
        :param phase: 'train', 'val' or 'test'; any other value makes
            ``make_items`` raise RuntimeError.
        :param transform: optional callable applied to every loaded image.
        :param target_transform: optional callable applied to a sample's box
            list (train/val only).
        :param loader: callable mapping an image path to an image object.
        """
        self.phase = phase
        self.train_label_file = os.path.join(label_folder, 'wider_face_train_bbx_gt.txt')
        self.val_label_file = os.path.join(label_folder, 'wider_face_val_bbx_gt.txt')
        self.test_filelist = os.path.join(label_folder, 'wider_face_test_filelist.txt')
        # Train/val: list of (img_name, bbxes) tuples; test: list of names.
        self.imgs = self.make_items() if self.phase != 'test' else self.get_test_imgs()
        self.loader = loader
        self.transform = transform
        self.target_transform = target_transform

    def get_test_imgs(self):
        """Return the relative image paths listed in the test file list.

        Blank lines (e.g. a trailing newline) are skipped so they do not
        become empty image names.
        """
        # Read-only mode: 'r+' would demand write permission on the dataset.
        with open(self.test_filelist, 'r') as f:
            return [line.strip() for line in f if line.strip()]

    def make_items(self):
        """Parse the ground-truth file for the current phase.

        :return: list of ``(img_name, bbxes)`` tuples; an empty list for the
            'test' phase, which has no ground-truth file.
        :raises RuntimeError: if ``phase`` is not 'train', 'val' or 'test'.
        """
        if self.phase == 'train':
            label_file = self.train_label_file
        elif self.phase == 'val':
            label_file = self.val_label_file
        elif self.phase == 'test':
            # Same container type as the other branches.
            return []
        else:
            raise RuntimeError('Phase must be train or val')
        # Read-only mode: 'r+' would demand write permission on the dataset.
        with open(label_file, 'r') as f:
            lines = [line.strip() for line in f]
        return self._parse_bbx_lines(lines)

    @staticmethod
    def _is_numeric_row(line):
        """Return True if ``line`` is made only of whitespace-separated integers."""
        tokens = line.split()
        return bool(tokens) and all(tok.lstrip('-').isdigit() for tok in tokens)

    @classmethod
    def _parse_bbx_lines(cls, lines):
        """Turn stripped gt-file lines into ``(img_name, bbxes)`` tuples.

        A record is an image-name line, a box-count line, then one line per
        box (only its first four integers are kept). A record whose count is
        0 may be followed by a single numeric placeholder row (the published
        WIDER FACE gt files appear to use an all-zero row here -- verify
        against your copy); that row is consumed so it is not read as the
        next image name. A record cut short by end of file is dropped.
        """
        items = []
        pos, total = 0, len(lines)
        while pos < total:
            img_name = lines[pos]
            pos += 1
            if not img_name:
                # Blank separator / trailing line.
                continue
            if pos >= total:
                break
            bbx_num = int(lines[pos])
            pos += 1
            if pos + bbx_num > total:
                break
            # split() without an argument tolerates repeated/leading spaces.
            bbxes = [list(map(int, row.split()[:4]))
                     for row in lines[pos:pos + bbx_num]]
            pos += bbx_num
            if bbx_num == 0 and pos < total and cls._is_numeric_row(lines[pos]):
                pos += 1
            items.append((img_name, bbxes))
        return items

    def _load_image(self, img_path):
        """Load ``img_path`` with ``self.loader`` and apply ``self.transform``.

        :raises IOError: if the loader returns None.
        """
        img = self.loader(img_path)
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if img is None:
            raise IOError('Failed to load image: ' + img_path)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __getitem__(self, index):
        """
        :param index: position in ``self.imgs``.
        :return: ``img`` for the 'test' phase, otherwise ``(img, bbxes)``.
            ``bbxes`` is a fresh copy, so an in-place ``target_transform``
            cannot corrupt the cached annotations seen by later epochs.
        :raises RuntimeError: if ``phase`` is not 'train', 'val' or 'test'.
        """
        if self.phase == 'test':
            return self._load_image(os.path.join(test_folder, self.imgs[index]))
        if self.phase == 'train':
            folder = train_folder
        elif self.phase == 'val':
            folder = val_folder
        else:
            raise RuntimeError('Phase error.')
        img_name, bbxes = self.imgs[index]
        img = self._load_image(os.path.join(folder, img_name))
        bbxes = [list(bbx) for bbx in bbxes]
        if self.target_transform is not None:
            bbxes = self.target_transform(bbxes)
        return img, bbxes

    def __len__(self):
        return len(self.imgs)


def imshow(inp, title=None, mean=None, std=None):
    """Display an image tensor laid out as (C, H, W) with matplotlib.

    :param inp: image tensor; the transpose below assumes channel-first
        order (e.g. the output of ``torchvision.utils.make_grid``).
    :param title: optional figure title.
    :param mean: optional per-channel mean used to undo a ``Normalize``
        transform (defaults to 0 when only ``std`` is given).
    :param std: optional per-channel std used to undo a ``Normalize``
        transform (defaults to 1 when only ``mean`` is given).
    """
    # detach()/cpu() so tensors that require grad or live on a GPU also work.
    inp = inp.detach().cpu()
    if mean is not None or std is not None:
        inp = inp.float()
        channels = inp.shape[0]
        mean_t = torch.as_tensor(
            mean if mean is not None else [0.0] * channels,
            dtype=inp.dtype).view(-1, 1, 1)
        std_t = torch.as_tensor(
            std if std is not None else [1.0] * channels,
            dtype=inp.dtype).view(-1, 1, 1)
        # Undo normalisation, then clip into the displayable [0, 1] range.
        inp = (std_t * inp + mean_t).clamp(0, 1)
    plt.imshow(inp.numpy().transpose((1, 2, 0)))
    if title is not None:
        plt.title(title)
    # Brief pause lets an interactive (plt.ion) figure redraw.
    plt.pause(0.001)


def _detection_collate(batch):
    """Collate ``(img, bbxes)`` samples whose box lists differ in length.

    Images are stacked into one tensor, so they must share a shape (the
    fixed-size Resize in ``dataset_test`` ensures this). Boxes stay a plain
    list with one entry per sample, because the default collate raises on
    nested lists of unequal length.
    """
    imgs = torch.stack([sample[0] for sample in batch], 0)
    bbxes = [sample[1] for sample in batch]
    return imgs, bbxes


def dataset_test(phase='train'):
    """Load one shuffled batch of ``phase`` and show it as an image grid.

    :param phase: 'train', 'val' or 'test'.
    :raises RuntimeError: for any other phase.
    """
    plt.ion()
    train_transform = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor()
    ])
    dataset = WiderFaceDataSet(phase=phase, transform=train_transform)
    labelled = phase == 'train' or phase == 'val'
    loader_kwargs = dict(batch_size=16, shuffle=True, num_workers=4)
    if labelled:
        # Module-level (picklable) collate so it also works with worker
        # processes started via 'spawn'.
        loader_kwargs['collate_fn'] = _detection_collate
    dataloader = torch.utils.data.DataLoader(dataset, **loader_kwargs)
    if labelled:
        inputs, _ = next(iter(dataloader))
    elif phase == 'test':
        inputs = next(iter(dataloader))
    else:
        raise RuntimeError('phase error.')
    out = torchvision.utils.make_grid(inputs)
    imshow(out)
    print('the number of '+phase+' dataset is: '+str(len(dataset)))


if __name__ == '__main__':
    # Smoke-test the unlabelled split; use 'train' or 'val' for labelled batches.
    dataset_test(phase='test')
